op {
  graph_op_name: "AllToAll"
  visibility: HIDDEN
  in_arg {
    name: "input"
    description: <<END
The local input to the exchange.
END
  }
  in_arg {
    name: "group_assignment"
    description: <<END
An int32 tensor with shape
[num_groups, num_replicas_per_group]. `group_assignment[i]` represents the
replica ids in the ith subgroup.
END
  }
  out_arg {
    name: "output"
    description: <<END
The exchanged result.
END
  }
  attr {
    name: "T"
    description: <<END
The type of elements to be exchanged.
END
  }
  attr {
    name: "concat_dimension"
    description: <<END
The dimension number to concatenate.
END
  }
  attr {
    name: "split_dimension"
    description: <<END
The dimension number to split.
END
  }
  attr {
    name: "split_count"
    description: <<END
The number of splits. This number must be equal to the sub-group
size (`group_assignment.get_shape()[1]`).
END
  }
  summary: "An Op to exchange data across TPU replicas."
  description: <<END
On each replica, the input is split into `split_count` blocks along
`split_dimension` and sent to the other replicas given group_assignment. After
receiving `split_count` - 1 blocks from other replicas, we concatenate the
blocks along `concat_dimension` as the output.

For example, suppose there are 2 TPU replicas:
replica 0 receives input: `[[A, B]]`
replica 1 receives input: `[[C, D]]`

group_assignment=`[[0, 1]]`
concat_dimension=0
split_dimension=1
split_count=2

replica 0's output: `[[A], [C]]`
replica 1's output: `[[B], [D]]`
END
}
